import sys
sys.path.append('./xds/xds_python/')

from torch import nn
from torch.nn import LSTM
from sklearn.metrics import r2_score

from preprocess.data_loader import monkey_spike_generator, get_monkey_spike_size, get_all_input_spike
from xds_utils import find_target_dir

import torch
import torchmetrics
import math
import torch.nn.functional as F

class cursor_Predictor(nn.Module):
    """LSTM regressor mapping an input sequence to a ``pos_dim``-sized vector.

    An LSTM encodes the (batch-first) input sequence; the top layer's final
    cell state is passed through ``BatchNorm1d`` + ``Linear`` to produce the
    prediction, which ``forward`` scores with mean-squared error.

    Args:
        input_size: number of features per time step fed to the LSTM.
        h_dim: LSTM hidden size (also the BatchNorm/Linear input width).
        lstm_layer: number of stacked LSTM layers.
        pos_dim: size of the output vector; also used as ``num_classes`` for
            the accuracy metric.
    """

    def __init__(self, input_size, h_dim, lstm_layer, pos_dim):
        super().__init__()
        # Variables
        self.input_size = input_size
        self.h_dim = h_dim
        self.lstm_layer = lstm_layer
        self.pos_dim = pos_dim

        # Decoder: sequence encoder followed by a per-sample read-out head.
        self.base_decoder = LSTM(input_size=self.input_size, hidden_size=self.h_dim,
                                 num_layers=self.lstm_layer, batch_first=True)
        self.fc_classifier = nn.Sequential(nn.BatchNorm1d(self.h_dim),
                                           nn.Linear(self.h_dim, self.pos_dim))
        # Losses / metrics. forward() only uses the MSE loss; the cross-entropy
        # criterion and accuracy metric serve the classification variant
        # (see get_dir_accuracy).
        self.cla_criterion = nn.CrossEntropyLoss()
        self.mse_regression = nn.MSELoss()
        self.acc_metrics = torchmetrics.Accuracy(task="multiclass", num_classes=self.pos_dim)

    def get_dir_accuracy(self, y_pred, target_dir):
        """Return multiclass accuracy of ``y_pred`` against one-hot ``target_dir``.

        ``target_dir`` is reduced to class indices with ``argmax`` over dim 1
        before being passed to the torchmetrics accuracy metric.
        """
        target_dir_gt = torch.argmax(target_dir, dim=1)
        acc = self.acc_metrics(y_pred, target_dir_gt)
        return acc

    def forward(self, input, target_dir):
        """Encode ``input`` and regress ``target_dir``.

        Args:
            input: batch-first sequence tensor for the LSTM — presumably
                (batch, seq_len, input_size); TODO confirm against callers.
            target_dir: regression target, expected to match the prediction's
                (batch, pos_dim) shape.

        Returns:
            (y_pred, mse_loss): the (batch, pos_dim) prediction and the scalar
            MSE between it and ``target_dir``.
        """
        # LSTM returns (output, (h_n, c_n)); h_n and c_n are
        # (num_layers, batch, h_dim), ordered from bottom to top layer.
        _, (_, cell_state) = self.base_decoder(input)
        # Read out the TOP layer's final cell state. Indexing [0] (the bottom
        # layer) would make every layer above it irrelevant to the output when
        # lstm_layer > 1. For a single-layer LSTM, [0] and [-1] coincide.
        # NOTE(review): the cell state (not h_n) is used here, as in the
        # original implementation.
        y_pred = self.fc_classifier(cell_state[-1])

        mse_loss = self.mse_regression(y_pred, target_dir)
        return y_pred, mse_loss

def covert_to_target_idx(day_target_dir):
    """Convert direction angles to class indices, in place.

    Each entry ``a`` is replaced by ``int(a / 45 + 3)`` (``int`` truncates
    toward zero). On the angle grid [-135, -90, ..., 180] this gives
    indices 0..7. The container is indexed along its first axis, modified in
    place, and also returned.
    """
    step_deg = 45
    offset = 3
    n_entries = day_target_dir.shape[0]
    for pos in range(n_entries):
        angle_deg = day_target_dir[pos]
        day_target_dir[pos] = int(angle_deg / step_deg + offset)
    return day_target_dir

# training cursor direction prediction on source cursor position
def train_cursor_predictor(model, device, day_cursor_pos_xy, dataset_config, hyparams_config,
                           *, batch_size_limit=10):
    """Train ``model`` on windows of source-day cursor positions.

    Batches come from ``monkey_spike_generator`` built on ``day_cursor_pos_xy``
    (used as both the input and the target source, with ``is_pos=True`` and no
    shuffling). The model is optimised with Adam on the MSE loss returned by
    ``model.forward`` for ``hyparams_config.cursor_training_steps`` steps. The
    loss is printed after every step.

    Args:
        model: module whose ``forward(x, y)`` returns ``(prediction, mse_loss)``.
        device: torch device the batches are moved to.
        day_cursor_pos_xy: per-trial cursor data for one day.
        dataset_config: provides ``cursor_window_size`` and ``batch_size``.
        hyparams_config: provides ``learning_rate``, ``weight_decay`` and
            ``cursor_training_steps``.
        batch_size_limit: batches with fewer rows than this are skipped and do
            not count as training steps.

    Raises:
        ValueError: if ``dataset_config.batch_size < batch_size_limit``.
            Assuming the generator never yields more than ``batch_size`` rows,
            every batch would then be skipped and the loop would never end.
    """
    if dataset_config.batch_size < batch_size_limit:
        raise ValueError(
            f"dataset_config.batch_size ({dataset_config.batch_size}) is smaller than "
            f"batch_size_limit ({batch_size_limit}); every batch would be skipped")

    # training data generator
    shuffle_flag, pos_flag = False, True
    src_train_generator = monkey_spike_generator(day_spike=day_cursor_pos_xy,
                                                 day_cursor=day_cursor_pos_xy,
                                                 window_size=dataset_config.cursor_window_size,
                                                 batch_size=dataset_config.batch_size,
                                                 is_shuffle=shuffle_flag,
                                                 is_pos=pos_flag)

    optimizer = torch.optim.Adam(model.parameters(), lr=hyparams_config.learning_rate,
                                 weight_decay=hyparams_config.weight_decay)

    # Nothing in the loop switches the mode back, so set it once.
    model.train()
    global_step = 0
    while global_step < hyparams_config.cursor_training_steps:
        train_batch_cursor, train_batch_target_dir, _ = next(src_train_generator)

        # Skip undersized (e.g. trailing) batches without advancing the step.
        if train_batch_cursor.shape[0] < batch_size_limit:
            continue

        train_x = torch.tensor(train_batch_cursor).to(device)
        train_y = torch.tensor(train_batch_target_dir).to(device)

        _, mse_loss = model.forward(train_x, train_y)
        print("Training: MSE Loss = %f" % (mse_loss))

        optimizer.zero_grad()
        mse_loss.backward()
        optimizer.step()

        global_step += 1

def get_cursor_position_structure(model, device, pred_cursor_pos, gt_spike, dataset_config):
    """Score ``model`` with R² on windows built from predicted cursor positions.

    ``pred_cursor_pos`` is a flat, row-concatenated array covering all trials.
    It is split back into per-trial chunks using the lengths implied by
    ``gt_spike``: trial ``t`` owns ``len(gt_spike[t]) - window_size + 1``
    rows. The chunks are batched with ``monkey_spike_generator`` (used as both
    input and target source), run through ``model`` in eval mode without
    gradients, and scored with ``sklearn.metrics.r2_score``.

    R² is computed once over all evaluated windows pooled together. Averaging
    per-batch R² instead is biased by batch composition. It also becomes NaN
    whenever a batch holds a single sample, because R² is undefined there.

    Args:
        model: module whose ``forward(x, y)`` returns ``(prediction, loss)``.
        device: torch device the batches are moved to.
        pred_cursor_pos: 2-D array of predicted cursor rows for all trials.
        gt_spike: per-trial arrays whose lengths define the trial split.
        dataset_config: provides ``window_size``, ``cursor_window_size`` and
            ``batch_size``.

    Returns:
        The pooled R² score.

    Raises:
        ValueError: if there are no windows to evaluate.

    Side effect: leaves ``model`` in eval mode.
    """
    # Re-split the flat predictions into per-trial chunks.
    window_size = dataset_config.window_size
    day_pred_cursor = []
    start = 0
    for trial in gt_spike:
        trial_len = trial.shape[0] - window_size + 1
        day_pred_cursor.append(pred_cursor_pos[start:start + trial_len, :])
        start += trial_len

    shuffle_flag, pos_flag = False, True
    pred_test_generator = monkey_spike_generator(day_spike=day_pred_cursor,
                                                 day_cursor=day_pred_cursor,
                                                 window_size=dataset_config.cursor_window_size,
                                                 batch_size=dataset_config.batch_size,
                                                 is_shuffle=shuffle_flag,
                                                 is_pos=pos_flag)
    pred_test_size = get_monkey_spike_size(day_spike=day_pred_cursor,
                                           day_cursor=day_pred_cursor,
                                           window_size=dataset_config.cursor_window_size,
                                           is_pos=pos_flag)

    pred_test_epoch = int(math.ceil(pred_test_size / float(dataset_config.batch_size)))
    if pred_test_epoch == 0:
        raise ValueError("no cursor windows to evaluate")

    model.eval()
    pred_batches, true_batches = [], []
    with torch.no_grad():
        for _ in range(pred_test_epoch):
            test_batch_pred_x, test_batch_pred_y, _ = next(pred_test_generator)

            test_x = torch.tensor(test_batch_pred_x).to(device)
            test_y = torch.tensor(test_batch_pred_y).to(device)

            test_pred_y, _ = model.forward(test_x, test_y)
            pred_batches.append(test_pred_y.cpu())
            true_batches.append(test_y.cpu())

    y_true = torch.cat(true_batches).numpy()
    y_pred = torch.cat(pred_batches).numpy()
    return r2_score(y_true, y_pred)

def fine_tune_cursor_decoder(model, cur_predictor, device, tgt_day_spike_train, tgt_day_cursor_train,
                             dataset_config, optimizer_cur, *, batch_size_limit=10):
    """Fine-tune through ``model.pos_read_out`` using a frozen cursor predictor.

    Steps:
      1. Format the target-day spikes into windows (``get_all_input_spike``)
         and compute position latents with ``model.get_model_pos_latent``.
         This runs without gradient because the result is detached anyway.
      2. Split the flat latents back into per-trial chunks and batch them with
         ``monkey_spike_generator``.
      3. Freeze ``cur_predictor`` by setting ``requires_grad=False`` and
         switching it to eval mode. Eval mode makes its BatchNorm use the
         stored running statistics and stops it from updating them.
      4. For each batch, decode the latents with ``model.pos_read_out``, feed
         the decoded sequence to ``cur_predictor``, and step ``optimizer_cur``
         on its MSE loss against the read-out of the batch targets.

    Args:
        model: provides ``get_model_pos_latent``, ``pos_read_out``,
            ``latent_dim`` and ``pos_dim``.
        cur_predictor: module whose ``forward(x, y)`` returns
            ``(prediction, mse_loss)``.
        device: torch device tensors are moved to.
        tgt_day_spike_train: per-trial spike arrays for the target day.
        tgt_day_cursor_train: per-trial cursor arrays for the target day.
        dataset_config: provides ``window_size``, ``cursor_window_size`` and
            ``batch_size``.
        optimizer_cur: optimizer stepped on the MSE loss.
        batch_size_limit: batches with fewer rows than this are skipped.

    Side effects: ``cur_predictor``'s parameters are left with
    ``requires_grad=False`` and the module is left in eval mode.
    """
    tgt_day_spike_train_format, _ = get_all_input_spike(tgt_day_spike_train, tgt_day_cursor_train,
                                                        dataset_config.window_size)
    tgt_day_spike_train_format_tensor = torch.tensor(tgt_day_spike_train_format).to(device)

    # The latents are detached right away, so skip building an autograd graph
    # over the whole day.
    with torch.no_grad():
        pos_latent_tgt = model.get_model_pos_latent(src_x=tgt_day_spike_train_format_tensor,
                                                    domain_flag=False, train_flag=False)
    pos_latent_tgt = pos_latent_tgt.detach().cpu().numpy()

    # Transform the flat latents back into per-trial chunks.
    window_size = dataset_config.window_size
    day_pos_latent = []
    start = 0
    for trial in tgt_day_spike_train:
        trial_len = trial.shape[0] - window_size + 1
        day_pos_latent.append(pos_latent_tgt[start:start + trial_len, :])
        start += trial_len

    shuffle_flag, pos_flag = False, True
    pred_train_generator = monkey_spike_generator(day_spike=day_pos_latent,
                                                  day_cursor=day_pos_latent,
                                                  window_size=dataset_config.cursor_window_size,
                                                  batch_size=dataset_config.batch_size,
                                                  is_shuffle=shuffle_flag,
                                                  is_pos=pos_flag)
    pred_train_size = get_monkey_spike_size(day_spike=day_pos_latent,
                                            day_cursor=day_pos_latent,
                                            window_size=dataset_config.cursor_window_size,
                                            is_pos=pos_flag)

    pred_train_epoch = int(math.ceil(pred_train_size / float(dataset_config.batch_size)))

    # Fix the cursor predictor: no parameter updates, and BatchNorm uses (and
    # does not overwrite) its running statistics.
    for param in cur_predictor.parameters():
        param.requires_grad = False
    cur_predictor.eval()

    for _ in range(pred_train_epoch):
        train_batch_pred_x, train_batch_pred_y, _ = next(pred_train_generator)

        # Check the size before creating device tensors for a batch we drop.
        batch_size = train_batch_pred_x.shape[0]
        if batch_size < batch_size_limit:
            continue

        train_batch_pred_x = torch.tensor(train_batch_pred_x).to(device)
        train_batch_pred_y = torch.tensor(train_batch_pred_y).to(device)

        # Decode every time step: flatten to (-1, latent_dim), read out the
        # positions, then restore (batch_size, -1, pos_dim). The middle axis
        # is presumably the time/window axis; TODO confirm.
        flat_latent = torch.reshape(train_batch_pred_x, (-1, model.latent_dim))
        train_batch_pred_cur = torch.reshape(model.pos_read_out(flat_latent),
                                             (batch_size, -1, model.pos_dim))

        # NOTE(review): the target is also produced by pos_read_out and is not
        # detached, so gradients flow through both prediction and target.
        train_batch_pred_gt = model.pos_read_out(train_batch_pred_y)
        _, mse_loss = cur_predictor.forward(train_batch_pred_cur, train_batch_pred_gt)

        print("target pred MSE Loss = %f" % (mse_loss))

        optimizer_cur.zero_grad()
        mse_loss.backward()
        optimizer_cur.step()